import numpy as np
import quantecon as qe
from scipy.interpolate import interp1d
from scipy.optimize import minimize
from scipy.optimize import broyden1
from scipy.optimize import brentq
from scipy.stats import norm
import matplotlib.pyplot as plt
from numpy.linalg import inv
from mpl_toolkits import mplot3d
import support_functions as sp
import pandas as pd
from numpy.polynomial.legendre import leggauss
import pickle
import time
import os
from scipy.optimize import minimize
from scipy.optimize import root
import seaborn as sns

# Parameter values
# Notes on each constant describe how it is used further down in this file.
beta = 0.99  # discount factor applied to continuation values in the value-function iterations
inv_frisch = 5  # enters the labour-supply block as the exponent (1 + inv_frisch) on labour
gamma = 1  # enters the labour-supply block as the exponent (1 - gamma) on output
phi_y = 0.5 / 4  # not used in the visible part of the file (presumably a Taylor-rule output coefficient)
phi_pi = 1.5  # not used in the visible part of the file (presumably a Taylor-rule inflation coefficient)
SIGMA_CES = 5  # elasticity of the CES demand curve used when USE_CES is True
DRS_ALPHA = 1  # cost is w * (output / A) ** (1 / DRS_ALPHA); 1 means linear cost in output

menu_cost = 0.02  # adjustment cost, charged as this fraction of the period's revenue

# Shock parameters
persistence_w_shock = 0.5
initial_w_shock = 0.02
# NOTE(review): initial / persistence equals initial / (1 - persistence) only
# because persistence is 0.5 -- confirm which formula is intended.
final_shock = initial_w_shock/persistence_w_shock

# Running parameters
# Iteration caps and convergence tolerances. Only backward_iter_max,
# backward_iter_ss_thresh, foreward_iter_max and foreward_iter_thresh are read
# in the visible part of the file.
backward_iter_max = 10000
backward_iter_thresh = 1E-3
backward_iter_ss_thresh = 1E-6
MIT_shock_iter_max = 100
MIT_iter_thresh = 1E-8
foreward_iter_max = 1000
foreward_iter_thresh = 1E-2

# Toggle parameters to run the system
USE_CES = False  # swap the tabulated demand curve for the CES one
DEMAND_SHOCKS = False  # not read in the visible part of the file
ROUWENHURST_SYM = False  # replace the banded productivity process with a Rouwenhorst chain
MASK = True  # drop the lowest/highest productivity states when aggregating

# Generate model primitives from Darwinian data
shares,mu,rho,A,y = sp.get_key_darwinian_data()
# The markups returned above are replaced by ones generated from rho and shares
# (third argument 1.15; its meaning is defined in support_functions).
mu = sp.generate_markups(rho, shares, 1.15)
# Productivity is rebuilt recursively in logs, starting from A[0] = 1:
# dlog A_i = (mu_{i-1} - 1) / rho_{i-1} * dlog share_i.
A = np.ones(len(mu))
for i in range(1,len(A)):
	A[i] = np.exp(np.log(A[i-1]) + (mu[i-1]-1)/rho[i-1] * (np.log(shares[i]) - np.log(shares[i-1])))

# Price implied by markup over the inverse of productivity.
p = mu / A

def unit_root_transition_process(n, q):
	"""Build a banded transition matrix for a symmetric random walk on n states.

	``q[k]`` is the probability of moving exactly ``k + 1`` grid points, in
	either direction. Whatever probability remains in a row is placed on the
	diagonal, so every row sums to one; rows near the edges have fewer
	reachable neighbours and therefore keep more mass in place.

	Parameters
	----------
	n : int
		Number of states.
	q : sequence of float
		Step probabilities, one entry per step size starting at one.

	Returns
	-------
	numpy.ndarray
		Array of shape ``(n, n)``. Nothing checks that the diagonal stays
		non-negative; that is up to the caller's choice of ``q``.
	"""
	state_index = np.arange(n)
	# |i - j| for every pair of states, computed once instead of per step size.
	distance = np.abs(np.subtract.outer(state_index, state_index))
	T_mat = np.zeros((n, n))
	for step, prob in enumerate(q, start=1):
		T_mat += (distance == step) * prob
	# Residual probability of not moving.
	T_mat += np.diag(1 - T_mat.sum(axis=1))
	return T_mat

# Productivity grid: n_A points, equally spaced in logs and centred on zero.
n_A = 71 
dlogA_spacing = 0.02
dlogA_transition_prob = 0.15  # not read in the visible part of the file
dlogA_grid = np.linspace(-dlogA_spacing*(n_A-1)/2, dlogA_spacing*(n_A-1)/2, n_A)
A_grid = np.exp(dlogA_grid)
# Normalise so the lowest productivity level equals one.
A_grid = A_grid/A_grid[0]
# Banded transition matrix: probabilities of moving 1, 2 and 3 grid points.
# NOTE(review): the functions below read a global named `Pi`, which is not
# assigned in the visible part of the file -- presumably derived from this
# matrix elsewhere; confirm.
Pi_dlogA_symmetric = unit_root_transition_process(n_A, [0.15, 0.08, 0.04])
# (Almost) uniform initial weights over productivity states; the small offset
# in the denominator makes them sum to slightly less than one.
wt = np.ones(n_A) / (n_A + 0.000001)
if ROUWENHURST_SYM:
	# Alternative: Rouwenhorst discretisation from quantecon with 21 states.
	n_A = 21
	persistence_dlogA = 0.98;
	sigma_dlogA = 0.06;
	dlogA_markov_chain = qe.markov.approximation.rouwenhorst(n_A, 0, sigma_dlogA, persistence_dlogA)
	Pi_dlogA_symmetric = dlogA_markov_chain.P
	# Reads a private attribute of the quantecon chain for the state values.
	A_grid = np.exp(dlogA_markov_chain._state_values)
	A_grid = A_grid/A_grid[0]
	wt = dlogA_markov_chain.stationary_distributions[0]

# Index bounds along the productivity axis. When MASK is on, states below
# MASK_LO and from MASK_HI upwards are zeroed out before aggregating.
MASK_LO_PCT = 0.1
MASK_HI_PCT = 0.9
MASK_LO = int(MASK_LO_PCT*n_A)
MASK_HI = int(MASK_HI_PCT*n_A)
# Productivity index that separates the "small" and "large" groups used for
# the group-level markups and price-change fractions.
LARGE_CUTOFF = int(0.7*n_A)

# Price grid
dlogp_min = -0.5
dlogp_max = 0.5
# NOTE(review): this first grid (and its length) is overwritten a few lines
# below by the wider, finer grid; only dlogp_min / dlogp_max survive from it.
dlogp_grid = np.arange(dlogp_min, dlogp_max, 0.002)
n_dlogp = len(dlogp_grid)
# Grid actually used: log prices from -1.3 up to (excluding) 1.5 in steps of 0.001.
dlogp_grid = np.arange(-1.3, 1.5, 0.001)

n_dlogp = len(dlogp_grid)

# Demand shocks
# Degenerate one-state process: a single state at zero with an identity transition.
n_dlogB = 1
Pi_dlogB = np.array([[1]])
dlogB_grid = np.array([0])

# Finally, set up the whole state grid
# 'ij' indexing gives arrays of shape (n_dlogp, n_A, n_dlogB): axis 0 is the
# log price, axis 1 the productivity level, axis 2 the demand shock.
dlogp_mesh,A_mesh,dlogB_mesh = np.meshgrid(dlogp_grid, A_grid, dlogB_grid, indexing='ij')

# Cache file for the tabulated demand curve (read or written by load_demand_curve).
DEMAND_CURVE_SAVE_NAME = './data/data_temp/demand_curve_extended.dta'

def load_demand_curve():
	"""Return the demand curve and the cumulated p*dy curve as interpolants.

	If the cache file DEMAND_CURVE_SAVE_NAME exists, both objects are built
	from its 'logp' and 'y' columns. Otherwise the curve is constructed from
	the module-level arrays p and y (interpolated in logs inside the observed
	price range, extrapolated outside it), tabulated on a log-price grid from
	10 down to -10, and written to the cache file.

	Returns
	-------
	demand_curve : scipy.interpolate.interp1d
		Quantity as a function of the price level (not the log price).
	Upsilon_toconst : scipy.interpolate.interp1d
		Running sum of price * (change in y) along the tabulated curve, as a
		function of y, starting from zero at the first tabulated point.
	"""
	if os.path.exists(DEMAND_CURVE_SAVE_NAME):
		demand_curve_df = pd.read_stata(DEMAND_CURVE_SAVE_NAME)
		demand_curve = interp1d(np.exp(demand_curve_df['logp']), demand_curve_df['y'])
		y_arr = demand_curve_df['y'].to_numpy()
		# The price plays the role of the derivative that gets cumulated below.
		Upsilon_prime_arr = np.exp(demand_curve_df['logp']).to_numpy()
		# Left-endpoint cumulative sum of p * dy, with 0 prepended for the first point.
		Upsilon_arr = np.append([0], np.cumsum(Upsilon_prime_arr[:-1] * np.diff(y_arr)))
		Upsilon_toconst = interp1d(y_arr, Upsilon_arr)
	else:
		# NOTE(review): `elast` is read below but is not assigned anywhere in the
		# visible part of the file before this function is called at module
		# level -- this branch presumably fails with NameError unless it is
		# defined elsewhere; confirm.
		# NOTE(review): second_order_approx_atind is never used in this function.
		second_order_approx_atind = lambda logy,ind: np.log(p[ind]) + -1/elast[ind] * (logy-np.log(y[ind])) - 1/(elast[ind] * mu[ind]) * (1-rho[ind])/rho[ind] * (logy-np.log(y[ind]))**2
		# Log quantity as a function of log price around data point `ind`; the
		# max(0, ...) keeps the square root real.
		second_order_approx_p_atind = lambda logp,ind: np.log(y[ind]) + mu[ind]*rho[ind]/(1-rho[ind]) * (-0.5 + np.sqrt(max(0, 0.25 - elast[ind]/mu[ind] * (1-rho[ind])/rho[ind] * (logp-np.log(p[ind])))))
		def get_second_order_approx_p(ind):
			return lambda logp: second_order_approx_p_atind(logp,ind)
		# Low prices are extrapolated from the last data point.
		extrapolate_lowprices = get_second_order_approx_p(-1)
		# This assignment is immediately replaced by the first-order version below.
		extrapolate_highprices = get_second_order_approx_p(0)
		# Use first-order approximation for extrapolating high prices
		extrapolate_highprices = lambda logp: np.log(y[0]) - elast[0]*(logp-np.log(p[0]))
		log_demand_curve = interp1d(np.log(p), np.log(y))
		def demand_curve_comp(relative_p):
			# Scalar piecewise rule: above the observed range, inside it, below it.
			if relative_p > p.max():
				return np.exp(extrapolate_highprices(np.log(relative_p)))
			if relative_p >= p.min():
				return np.exp(log_demand_curve(np.log(relative_p)))
			else:
				return np.exp(extrapolate_lowprices(np.log(relative_p)))
		demand_curve_comp = np.vectorize(demand_curve_comp)
		# One million log prices, ordered from high to low.
		logp_expanded = np.linspace(10, -10, 1000000)
		demand_curve = interp1d(np.exp(logp_expanded), demand_curve_comp(np.exp(logp_expanded)))
		# The element-wise curve is evaluated a second time here to fill the cache file.
		pd.DataFrame({'logp': logp_expanded, 'y': demand_curve_comp(np.exp(logp_expanded))}).to_stata(DEMAND_CURVE_SAVE_NAME, write_index=False)
		y_arr = demand_curve(np.exp(logp_expanded))
		Upsilon_prime_arr = np.exp(logp_expanded)
		# Same cumulative p * dy construction as in the cached branch.
		Upsilon_arr = np.append([0], np.cumsum(Upsilon_prime_arr[:-1] * np.diff(y_arr)))
		Upsilon_toconst = interp1d(y_arr, Upsilon_arr)
	return demand_curve,Upsilon_toconst

# Tabulated demand system (loaded from the cache file or rebuilt; runs at import time).
demand_curve,Upsilon_toconst = load_demand_curve()
# CES alternative: y = p ** (-SIGMA_CES) and the matching aggregator term
# y ** ((SIGMA_CES - 1) / SIGMA_CES).
demand_curve_CES = lambda p: np.exp(-SIGMA_CES * np.log(p))
Upsilon_toconst_CES = lambda y: y**((SIGMA_CES-1)/(SIGMA_CES))

# The USE_CES toggle replaces both module-level callables used by every function below.
if USE_CES:
	demand_curve = demand_curve_CES
	Upsilon_toconst = Upsilon_toconst_CES

def profit(dlogp, realized_A, dlogB, dlogP, dlogw, dlogY, p_ss, P_ss, w_ss, Y_ss):
	"""Return period profit: revenue minus the wage bill.

	Every ``dlog*`` argument is a log deviation that scales the matching
	``*_ss`` level. Demand is the module-level ``demand_curve`` evaluated at
	the price divided by ``exp(dlogB)`` and by the price index, then scaled
	by aggregate output. Cost is ``wage * (output / realized_A) ** (1 / DRS_ALPHA)``.
	Inputs may be scalars or arrays that broadcast together.
	"""
	price = p_ss * np.exp(dlogp)
	wage = w_ss * np.exp(dlogw)
	price_index = P_ss * np.exp(dlogP)
	aggregate_output = Y_ss * np.exp(dlogY)
	# Price adjusted for the demand shifter, relative to the price index.
	adjusted_price = price / np.exp(dlogB)
	quantity = aggregate_output * demand_curve(adjusted_price / price_index)
	wage_bill = wage * (quantity / realized_A) ** (1 / DRS_ALPHA)
	return price * quantity - wage_bill

def revenue(dlogp, realized_A, dlogB, dlogP, dlogw, dlogY, p_ss, P_ss, w_ss, Y_ss):
	"""Return period revenue (quantity demanded times price).

	The signature mirrors ``profit`` so both can be called with the same
	arguments; ``realized_A``, ``dlogw`` and ``w_ss`` do not affect revenue
	and are ignored. Every other ``dlog*`` argument is a log deviation that
	scales the matching ``*_ss`` level. Inputs may be scalars or arrays that
	broadcast together.
	"""
	realized_p = p_ss*np.exp(dlogp)
	realized_P = P_ss*np.exp(dlogP)
	realized_Y = Y_ss*np.exp(dlogY)
	# Price adjusted for the demand shifter.
	quality_adj_p = realized_p/np.exp(dlogB)
	return realized_Y*demand_curve(quality_adj_p/realized_P) * realized_p

def generate_ss_value_function(P_ss, w_ss, Y_ss, init_value=None):
	"""Iterate the steady-state menu-cost Bellman equation to convergence.

	Parameters
	----------
	P_ss, w_ss, Y_ss : float
		Steady-state price index, wage and output passed to ``profit`` and
		``revenue``.
	init_value : numpy.ndarray, optional
		Starting guess of shape ``(n_dlogp, n_A, n_dlogB)``. Defaults to the
		period profit capitalised at ``1 / (1 - beta)``.

	Returns
	-------
	V_new : numpy.ndarray
		Value at each (price, productivity, demand) state, evaluated after
		the state has been realised.
	policy : numpy.ndarray
		Log price chosen at each state: the incoming price where the firm
		keeps it, the best reset price where it adjusts.

	Stops when the largest absolute change in the value function is at most
	``backward_iter_ss_thresh`` or after ``backward_iter_max`` rounds, and
	prints the change and the number of changed policy cells every round.
	"""
	profit_mat = profit(dlogp_mesh, A_mesh, dlogB_mesh, 0, 0, 0, 1, P_ss, w_ss, Y_ss)
	revenue_mat = revenue(dlogp_mesh, A_mesh, dlogB_mesh, 0, 0, 0, 1, P_ss, w_ss, Y_ss)
	adjustment_cost = menu_cost*revenue_mat
	Pi_full = np.kron(Pi, Pi_dlogB)

	def discounted_expectation(values):
		# Expectation over next period's (A, B) state, holding the price fixed.
		flat = np.reshape(values, (n_dlogp, n_A*n_dlogB), order='C')
		return beta * np.reshape(flat @ Pi_full.T, (n_dlogp, n_A, n_dlogB), order='C')

	V_curr = np.zeros(profit_mat.shape)
	policy_last = np.zeros(profit_mat.shape)
	if init_value is None:
		V_new = profit_mat / (1-beta)
	else:
		V_new = init_value.copy()
	policy = dlogp_mesh.copy()
	for _ in range(backward_iter_max):
		gap = np.max(abs(V_curr - V_new))
		if gap <= backward_iter_ss_thresh:
			break
		print(gap, (policy_last != policy).sum())
		policy_last = policy.copy()
		V_curr = V_new.copy()
		# Subtract from max value to keep definition in case of ties
		V_curr_normalized = V_curr - np.amax(V_curr, axis=0)
		# True values are stored; the normalised ones are only used to decide
		# whether and where to adjust.
		value_without_menu_costs = profit_mat + discounted_expectation(V_curr)
		value_with_menu_costs = value_without_menu_costs - adjustment_cost
		value_without_menu_costs_normalized = profit_mat + discounted_expectation(V_curr_normalized)
		value_with_menu_costs_normalized = value_without_menu_costs_normalized - adjustment_cost
		# Best reset price for each (A, B) state, net of the menu cost.
		dlogp_max_policy = dlogp_grid[np.argmax(value_with_menu_costs_normalized, axis=0)]
		dlogp_max_value = np.amax(value_with_menu_costs, axis=0)
		dlogp_max_value_normalized = np.amax(value_with_menu_costs_normalized, axis=0)
		# Adjust wherever keeping the current price is strictly worse.
		adjust = value_without_menu_costs_normalized < dlogp_max_value_normalized
		V_new = np.where(adjust, dlogp_max_value, value_without_menu_costs)
		policy = np.where(adjust, dlogp_max_policy, dlogp_mesh)
	return V_new,policy

def calculate_other_vars_dist(ss_dist, P_ss, w_ss, Y_ss):
	"""Aggregate a distribution over states into steady-state summary values.

	``ss_dist`` holds mass on the (price, productivity, demand) grid. When
	MASK is on, productivity states below MASK_LO and from MASK_HI upwards
	are given zero weight first.

	Returns, in order: total sales, total labour, ``Y_ss`` over total labour,
	the sales-weighted harmonic average markup, the labour-supply constant
	``Y_ss**(1-gamma) / (markup * labour**(1+inv_frisch))``, and one minus
	the weighted sum of the Upsilon terms.
	"""
	weights = ss_dist.copy()
	if MASK:
		weights[:,:MASK_LO,:] = 0
		weights[:,MASK_HI:,:] = 0
	# Per-state objects on the full grid.
	price = np.exp(dlogp_mesh)
	shifter = np.exp(dlogB_mesh)
	quantity = Y_ss*demand_curve(price/P_ss * 1/shifter)
	labour = quantity / A_mesh
	markup = price * A_mesh / w_ss
	sales = price * quantity
	upsilon = shifter * Upsilon_toconst(quantity/Y_ss)
	# Weighted aggregates.
	total_sales = (sales * weights).sum()
	total_labour = (labour * weights).sum()
	aggregate_productivity = Y_ss / total_labour
	average_markup = (total_sales / ((sales / markup) * weights).sum())
	labour_constant = (Y_ss)**(1-gamma) / (average_markup * total_labour**(1 + inv_frisch))
	upsilon_residual = 1 - (upsilon * weights).sum()
	return total_sales,total_labour,aggregate_productivity,average_markup,labour_constant,upsilon_residual


def calculate_ss_distribution(policy):
	"""Iterate the distribution of firms over states until it stops changing.

	Starts with ``wt * N_FIRMS`` firms at the price grid point found by
	searching for a log price of zero, spread over productivity. Each round
	first applies the exogenous (A, B) transition with the price held fixed
	and then moves every cell's mass to the price prescribed by ``policy``.

	Parameters
	----------
	policy : numpy.ndarray
		Chosen log price per state, shape ``(n_dlogp, n_A, n_dlogB)``; values
		are looked up on ``dlogp_grid`` with ``np.searchsorted``.

	Returns
	-------
	numpy.ndarray
		Mass per state after convergence (largest absolute change below
		``foreward_iter_thresh``) or after ``foreward_iter_max`` rounds.
		The largest change and the total mass are printed every round.
	"""
	N_FIRMS = 100000
	ss_dist_new = np.zeros(A_mesh.shape)
	ss_dist_new[np.searchsorted(dlogp_grid, 0), :, np.searchsorted(dlogB_grid, 0)] = wt * N_FIRMS
	policy_ind = np.searchsorted(dlogp_grid, policy)
	Pi_full = np.kron(Pi, Pi_dlogB)
	# Index arrays that broadcast against policy_ind, so every source cell
	# (p, a, b) is sent to (policy_ind[p, a, b], a, b).
	a_index = np.arange(n_A)[None,:,None]
	b_index = np.arange(n_dlogB)[None,None,:]
	ss_dist_curr = np.zeros(A_mesh.shape)
	for t in range(foreward_iter_max):
		gap = np.abs(ss_dist_new - ss_dist_curr).max()
		print(gap, ss_dist_new.sum())
		if gap < foreward_iter_thresh:
			break
		ss_dist_curr = ss_dist_new
		# Transition to new (A, B) states with prices unchanged
		ss_dist_old = np.reshape(np.reshape(ss_dist_curr, (n_dlogp, n_A*n_dlogB), order='C') @ Pi_full, (n_dlogp, n_A, n_dlogB), order='C')
		ss_dist_new = np.zeros(A_mesh.shape)
		# Unbuffered scatter-add: several source prices can map to one target.
		np.add.at(ss_dist_new, (policy_ind, a_index, b_index), ss_dist_old)
	return ss_dist_new

def plot_policy_function(policy):
	"""Draw the policy as a 3-D surface over productivity and incoming log price.

	Only the demand-shock slice located by searching ``dlogB_grid`` for zero
	is shown. A new figure is opened; nothing is returned or saved.
	"""
	b_zero = np.searchsorted(dlogB_grid,0)
	plt.figure()
	ax = plt.axes(projection='3d')
	ax.plot_surface(
		A_mesh[:,:,b_zero],
		dlogp_mesh[:,:,b_zero],
		policy[:,:,b_zero],
		rstride=1,
		cstride=1,
		cmap='viridis',
		edgecolor='none',
	)
	ax.set_xlabel('A')
	ax.set_ylabel('dlogp')
	ax.set_zlabel('Policy')

# def plot_price_changes(dlogp_val):
# 	T = dlogp_val.shape[1]
# 	plt.plot(range(T), dlogp_val.T)
# 	plt.figure()
# 	price_changes = np.diff(dlogp_val,axis=1)
# 	price_changes_flat = price_changes.ravel()
# 	price_changes_flat = price_changes_flat[price_changes_flat != 0]
# 	plt.hist(price_changes_flat, bins=12, density=True)
# 	plt.xlabel('Log price change')
# 	plt.ylabel('Frequency')

def generate_ss_values_dist(P_ss, w_ss, Y_ss):
	"""Solve the steady state for the given price index, wage and output.

	Runs the value-function iteration, then the distribution iteration on the
	resulting policy, then the aggregation. Returns
	``(policy, value_fn, ss_dist, aggregates)`` where ``aggregates`` is the
	tuple produced by ``calculate_other_vars_dist``.
	"""
	value_fn,policy = generate_ss_value_function(P_ss, w_ss, Y_ss)
	ss_dist = calculate_ss_distribution(policy)
	aggregates = calculate_other_vars_dist(ss_dist, P_ss, w_ss, Y_ss)
	return policy,value_fn,ss_dist,aggregates


def backward_iterate_valuefn(value_fn_i, policy_i, dlogP_path, dlogw_path, dlogY_path, P_ss, w_ss, Y_ss):
	"""Solve value and policy functions backwards along a path of aggregates.

	Parameters
	----------
	value_fn_i : numpy.ndarray
		Value function used as the continuation value after the last period,
		shape ``(n_dlogp, n_A, n_dlogB)``.
	policy_i : numpy.ndarray
		Accepted for interface compatibility; not used.
	dlogP_path, dlogw_path, dlogY_path : sequence of float
		Log deviations of the price index, wage and output per period. The
		horizon is ``len(dlogw_path)``.
	P_ss, w_ss, Y_ss : float
		Steady-state levels the deviations are applied to.

	Returns
	-------
	(numpy.ndarray, numpy.ndarray)
		Value functions and policies stacked along a leading time axis.
	"""
	T = len(dlogw_path)
	value_fn_path = [None]*T
	policy_path = [None]*T
	Pi_full = np.kron(Pi, Pi_dlogB)

	def discounted_expectation(values):
		# Expectation over next period's (A, B) state, holding the price fixed.
		flat = np.reshape(values, (n_dlogp, n_A*n_dlogB), order='C')
		return beta * np.reshape(flat @ Pi_full.T, (n_dlogp, n_A, n_dlogB), order='C')

	for t in reversed(range(T)):
		value_fn_next = value_fn_i if t+1 == T else value_fn_path[t+1]
		# Normalised copy is only used for the adjust / keep comparison (ties).
		v_next_normalized = value_fn_next - np.amax(value_fn_next, axis=0)
		profit_mat = profit(dlogp_mesh, A_mesh, dlogB_mesh, dlogP_path[t], dlogw_path[t], dlogY_path[t], 1, P_ss, w_ss, Y_ss)
		revenue_mat = revenue(dlogp_mesh, A_mesh, dlogB_mesh, dlogP_path[t], dlogw_path[t], dlogY_path[t], 1, P_ss, w_ss, Y_ss)
		adjustment_cost = menu_cost*revenue_mat
		# Generate value with and without normalization (true value vs. used for comparison)
		value_without_menu_costs = profit_mat + discounted_expectation(value_fn_next)
		value_with_menu_costs = value_without_menu_costs - adjustment_cost
		value_without_menu_costs_normalized = profit_mat + discounted_expectation(v_next_normalized)
		value_with_menu_costs_normalized = value_without_menu_costs_normalized - adjustment_cost
		# Best reset price for each (A, B) state, net of the menu cost.
		dlogp_max_policy = dlogp_grid[np.argmax(value_with_menu_costs_normalized, axis=0)]
		dlogp_max_value = np.amax(value_with_menu_costs, axis=0)
		dlogp_max_value_normalized = np.amax(value_with_menu_costs_normalized, axis=0)
		# Adjust wherever keeping the current price is strictly worse.
		adjust = value_without_menu_costs_normalized < dlogp_max_value_normalized
		value_fn_path[t] = np.where(adjust, dlogp_max_value, value_without_menu_costs)
		policy_path[t] = np.where(adjust, dlogp_max_policy, dlogp_mesh)
	return np.array(value_fn_path),np.array(policy_path)

def solve_for_P_dist(A_p_dist, P_ss, Ups_C_ss):
	"""Find the price index at which the Upsilon terms sum to ``1 - Ups_C_ss``.

	Parameters
	----------
	A_p_dist : numpy.ndarray
		Mass per (price, productivity, demand) state. When MASK is on,
		productivity states below MASK_LO and from MASK_HI upwards are
		ignored.
	P_ss : float
		Starting guess for the solver.
	Ups_C_ss : float
		Steady-state residual the weighted Upsilon sum has to match.

	Returns
	-------
	numpy.ndarray
		The solver's solution vector (one element). A ``RuntimeWarning`` is
		issued if the solver does not report success; the last iterate is
		still returned.
	"""
	import warnings

	A_p_dist_masked = A_p_dist.copy()
	if MASK:
		A_p_dist_masked[:,:MASK_LO,:] = 0 
		A_p_dist_masked[:,MASK_HI:,:] = 0 
	p_grid_dist = np.exp(np.log(1) + dlogp_mesh)
	B_grid_dist = np.exp(dlogB_mesh)
	def try_P(P_ss_try):
		# Residual of the aggregator condition at a candidate price index.
		y_grid_dist = demand_curve(p_grid_dist/P_ss_try * 1/B_grid_dist)
		Ups_grid_dist = B_grid_dist * Upsilon_toconst(y_grid_dist)
		return (Ups_grid_dist * A_p_dist_masked).sum() - (1 - Ups_C_ss)
	solution = root(try_P, P_ss)
	if not solution.success:
		warnings.warn('solve_for_P_dist: root finder did not converge (%s)' % solution.message, RuntimeWarning)
	return solution.x

def solve_for_P_dist_path(path_dist, P_ss, Ups_C_ss):
	"""Solve for the price index in every period of a distribution path.

	``path_dist`` is a sequence of per-period distributions; each one is
	passed to ``solve_for_P_dist`` with the same starting guess ``P_ss``.
	Returns a 1-D float array with one price index per period.
	"""
	T = len(path_dist)
	P_generated = np.zeros(T)
	for t in range(T):
		# The solver hands back a one-element vector; take the scalar out
		# explicitly instead of relying on NumPy's deprecated implicit
		# array-to-scalar conversion.
		P_generated[t] = np.asarray(solve_for_P_dist(path_dist[t], P_ss, Ups_C_ss)).item()
	return P_generated

def calculate_other_path_vars_t(A_p_dist, P_t_generated, w_t_generated, params):
	"""Compute one period's aggregates from a distribution, price index and wage.

	Parameters
	----------
	A_p_dist : numpy.ndarray
		Mass per (price, productivity, demand) state. When MASK is on,
		productivity states below MASK_LO and from MASK_HI upwards are
		ignored.
	P_t_generated, w_t_generated : float
		Price index and wage of the period.
	params : sequence
		Six steady-state values in the order returned by
		``calculate_other_vars_dist``; only the fifth (``alpha_ss``) is used.

	Returns
	-------
	tuple
		``(mu_bar, Y, L, P_Y, mu_small, mu_large)``: average markup, output,
		labour, sales per unit of output, and the average markups of the
		productivity groups below / from LARGE_CUTOFF.
	"""
	A_p_dist_masked = A_p_dist.copy()
	if MASK:
		A_p_dist_masked[:,:MASK_LO,:] = 0 
		A_p_dist_masked[:,MASK_HI:,:] = 0 
	P_Y_ss,L_ss,Aagg_ss,mu_bar_ss,alpha_ss,Ups_C_ss = params
	p_grid_dist = np.exp(np.log(1) + dlogp_mesh)
	A_grid_dist = A_mesh.copy()
	B_grid_dist = np.exp(dlogB_mesh)
	# Quantities per unit of aggregate output (Y is not known yet).
	y_per_Y = demand_curve(p_grid_dist/P_t_generated * 1/B_grid_dist)
	l_per_Y = y_per_Y / A_grid_dist
	mu_grid_dist = p_grid_dist * A_grid_dist / w_t_generated
	sale_per_Y = p_grid_dist * y_per_Y
	mu_bar_t_generated = (sale_per_Y*A_p_dist_masked).sum() / (sale_per_Y/mu_grid_dist * A_p_dist_masked).sum()
	labour_per_Y = (l_per_Y*A_p_dist_masked).sum()
	# Output solves the same relation that defines alpha in
	# calculate_other_vars_dist, alpha * mu_bar * L**(1+inv_frisch) = Y**(1-gamma),
	# with L = Y * labour_per_Y. Rearranging gives
	#   Y = (labour_per_Y * (alpha*mu_bar)**(1/(1+inv_frisch))) ** (-(1+inv_frisch)/(inv_frisch+gamma)).
	# The exponent used previously was the reciprocal, -(inv_frisch+gamma)/(inv_frisch+1),
	# which coincides with this one only when gamma == 1.
	Y_t_generated = (labour_per_Y*(alpha_ss*mu_bar_t_generated)**(1/(inv_frisch+1)))**(-(inv_frisch+1)/(inv_frisch+gamma))
	# Now refresh L, P^Y
	y_grid_dist = Y_t_generated*y_per_Y
	l_grid_dist = y_grid_dist / A_grid_dist
	sale_grid_dist = p_grid_dist * y_grid_dist
	L_t_generated = (l_grid_dist*A_p_dist_masked).sum()
	P_Y_t_generated = (sale_grid_dist*A_p_dist_masked).sum() / Y_t_generated

	def group_markup(group):
		# Sales-weighted harmonic average markup over a productivity slice.
		sales = sale_grid_dist[group]
		mass = A_p_dist_masked[group]
		return (sales*mass).sum()/(sales/mu_grid_dist[group]*mass).sum()

	mu_small_generated = group_markup((slice(None), slice(None, LARGE_CUTOFF), slice(None)))
	mu_large_generated = group_markup((slice(None), slice(LARGE_CUTOFF, None), slice(None)))
	return mu_bar_t_generated,Y_t_generated,L_t_generated,P_Y_t_generated,mu_small_generated,mu_large_generated

def calculate_other_path_vars(path_dist, P_generated, w_generated, params):
	"""Evaluate ``calculate_other_path_vars_t`` for every period of a path.

	``path_dist``, ``P_generated`` and ``w_generated`` are indexed by period;
	``params`` is passed through unchanged. Returns six 1-D float arrays of
	length ``len(path_dist)``: average markup, output, labour, P_Y, and the
	small- and large-group markups.
	"""
	n_periods = len(path_dist)
	# One row per returned series, one column per period.
	series = np.zeros((6, n_periods))
	for t in range(n_periods):
		series[:, t] = calculate_other_path_vars_t(path_dist[t], P_generated[t], w_generated[t], params)
	mu_bar_generated,Y_generated,L_generated,P_Y_generated,mu_small_generated,mu_large_generated = series
	return mu_bar_generated,Y_generated,L_generated,P_Y_generated,mu_small_generated,mu_large_generated

def plot_mu_over_path(path_dist, w_generated, saveprefix=None):
	"""Plot the mass-weighted average markup of each productivity state over time.

	Two charts are drawn: all productivity states, then only the states
	between MASK_LO and MASK_HI. With ``saveprefix`` set they are written to
	``<saveprefix>_markups_over_t.png`` and
	``<saveprefix>_markups_masked_over_t.png``. All open figures are closed
	on entry and between the two charts; the second one stays open.
	"""
	plt.close('all')
	n_periods = len(w_generated)
	price_times_A = np.exp(np.log(1) + dlogp_mesh) * A_mesh.copy()
	mu_levels = np.zeros((n_A, n_periods))
	for t in range(n_periods):
		markup = price_times_A / w_generated[t]
		mass = path_dist[t]
		# Collapse the demand axis, then the price axis, leaving one value per A.
		mu_levels[:,t] = (mass * markup).sum(axis=2).sum(axis=0)/(mass.sum(axis=2).sum(axis=0))

	def draw(lines, suffix):
		# One line per productivity state, periods on the horizontal axis.
		plt.plot(lines.T)
		plt.xlabel('Period')
		plt.ylabel('Markup (by productivity level)')
		fig = plt.gcf()
		if saveprefix is not None:
			fig.savefig(saveprefix + suffix, dpi=300, bbox_inches='tight')

	draw(mu_levels, '_markups_over_t.png')
	plt.close('all')
	draw(mu_levels[MASK_LO:MASK_HI,:], '_markups_masked_over_t.png')

def calculation_markup_passthrough_cov(ss_dist, policy_paths, P_ss, Y_ss, w_ss, params):
	"""Covariance between inverse markups and first-period price changes.

	For every state with positive mass in ``ss_dist`` and every reachable
	next (A, B) state, one record is built holding the old price, the price
	chosen under the first-period policy ``policy_paths[0]``, the transition
	probability and the state's steady-state sales share and markup.

	Returns
	-------
	cov : float
		Weighted covariance of inverse markup and price change over all
		records (weights: the ``density_sales`` column).
	E_cov_by_A : float
		Weighted average of the same covariance computed separately for each
		initial productivity state.
	cov_across_A : float
		Weighted covariance of the per-productivity means of both variables.

	Raises
	------
	ValueError
		If ``ss_dist`` has no cell with positive mass.
	"""
	P_Y_ss,L_ss,Aagg_ss,mu_bar_ss,alpha_ss,Ups_C_ss = params
	# Only the first period's policy is needed.
	policy_ind_first = np.searchsorted(dlogp_grid, policy_paths[0])
	last_dist = ss_dist.copy()
	p_grid_dist = np.exp(np.log(1) + dlogp_mesh)
	A_grid_dist = A_mesh.copy()
	B_grid_dist = np.exp(dlogB_mesh)
	y_grid_dist = Y_ss*demand_curve(p_grid_dist/P_ss  * 1/B_grid_dist)
	mu_grid_dist = p_grid_dist * A_grid_dist / w_ss
	sale_grid_dist = (p_grid_dist * y_grid_dist) / P_Y_ss
	records = []
	for a in range(n_A):
		# Reachable successors depend only on (a, b), not on the price.
		a_trans_grid = np.where(Pi[a,:] > 0)[0]
		for b in range(n_dlogB):
			b_trans_grid = np.where(Pi_dlogB[b,:] > 0)[0]
			for p_ind in np.where(last_dist[:,a,b]>0)[0]:
				old_price = dlogp_grid[p_ind]
				for a_trans in a_trans_grid:
					for b_trans in b_trans_grid:
						prob = Pi[a,a_trans] * Pi_dlogB[b,b_trans]
						new_price = dlogp_grid[policy_ind_first[p_ind,a_trans,b_trans]]
						# NOTE(review): 'density_sales' is sales share times
						# transition probability and does not include the
						# cell's mass -- confirm this weighting is intended.
						records.append({
							'a_old_ind': a,
							'b_old_ind': b,
							'a_new_ind': a_trans,
							'b_new_ind': b_trans,
							'dlogp_old': old_price,
							'dlogp_new': new_price,
							'price_change' : new_price - old_price,
							'prob' : prob,
							'init_density' : last_dist[p_ind,a,b],
							'density_frac' : last_dist[p_ind,a,b]*prob,
							'init_sales' : sale_grid_dist[p_ind,a,b],
							'density_sales' : sale_grid_dist[p_ind,a,b]*prob,
							'init_markup' : mu_grid_dist[p_ind,a,b],
							'inv_markup' : 1/mu_grid_dist[p_ind,a,b]
						})
	if not records:
		raise ValueError('ss_dist has no positive mass; nothing to compute')
	# A single frame from plain dicts; all columns are kept as floats.
	realized_passthough_df = pd.DataFrame(records, dtype=float)
	cov = sp.weighted_cov(realized_passthough_df['inv_markup'], realized_passthough_df['price_change'], realized_passthough_df['density_sales'])
	cov_by_A = np.zeros(n_A)
	for a in range(n_A):
		realized_passthough_df_a = realized_passthough_df.loc[realized_passthough_df['a_old_ind']==a]
		cov_by_A[a] = sp.weighted_cov(realized_passthough_df_a['inv_markup'], realized_passthough_df_a['price_change'], realized_passthough_df_a['density_sales'])
	E_cov_by_A = sp.weighted_exp(cov_by_A, realized_passthough_df.groupby('a_old_ind')['density_sales'].sum())
	realized_passthough_df['weighted_price_change'] = realized_passthough_df['price_change'] * realized_passthough_df['density_sales']
	realized_passthough_df['weighted_inv_markup'] = realized_passthough_df['inv_markup'] * realized_passthough_df['density_sales']
	E_by_A = realized_passthough_df.groupby('a_old_ind').agg({
		'density_sales' : 'sum',
		'weighted_price_change' : 'sum',
		'weighted_inv_markup' : 'sum'
	})
	cov_across_A = sp.weighted_cov(E_by_A['weighted_inv_markup']/E_by_A['density_sales'], E_by_A['weighted_price_change']/E_by_A['density_sales'], E_by_A['density_sales'])
	return cov,E_cov_by_A,cov_across_A


def calculate_frac_change(ss_dist, policy_ind_path, policy_ind_ss, T, BURN_PERIOD):
	"""Cumulative share of firms that have changed their price at least once.

	Starting from ``ss_dist``, each period applies the exogenous (A, B)
	transition and keeps only the mass whose policy index equals its current
	price index; mass that adjusts is dropped for good. Periods before
	``BURN_PERIOD`` use ``policy_ind_ss``, later ones
	``policy_ind_path[t - BURN_PERIOD]``.

	Returns three length-``T`` arrays: the share that has changed price
	overall, and the same measure restricted to productivity states below
	and from LARGE_CUTOFF (each relative to that group's initial mass).
	"""
	path_dist_remaining = np.zeros(np.insert(ss_dist.shape, 0, T))
	last_dist = ss_dist
	Pi_full = np.kron(Pi, Pi_dlogB)
	# Price index of every cell, broadcastable against a policy-index array.
	own_price_index = np.arange(n_dlogp)[:,None,None]
	for t in range(T):
		if t > 0:
			last_dist = path_dist_remaining[t-1]
		new_dist_pold = np.reshape(np.reshape(last_dist, (n_dlogp, n_A*n_dlogB), order='C') @ Pi_full, (n_dlogp, n_A, n_dlogB), order='C')
		policy_t = policy_ind_ss if t < BURN_PERIOD else policy_ind_path[t - BURN_PERIOD]
		# Cells with mass whose chosen price is the one they already have.
		keeps_price = (policy_t == own_price_index) & (new_dist_pold > 0)
		path_dist_remaining[t] = np.where(keeps_price, new_dist_pold, 0.0)
	frac_change_all = 1 - path_dist_remaining.sum(axis=3).sum(axis=2).sum(axis=1) / ss_dist.sum()
	frac_change_small = 1 - path_dist_remaining[:,:,:LARGE_CUTOFF,:].sum(axis=3).sum(axis=2).sum(axis=1) / ss_dist[:,:LARGE_CUTOFF,:].sum()
	frac_change_large = 1 - path_dist_remaining[:,:,LARGE_CUTOFF:,:].sum(axis=3).sum(axis=2).sum(axis=1) / ss_dist[:,LARGE_CUTOFF:,:].sum()
	return frac_change_all,frac_change_small,frac_change_large

# def calculate_frac_change_quintile_v2(ss_dist, policy_ind_path, policy_ind_ss, T, BURN_PERIOD):
# 	path_dist_remaining = np.zeros(np.insert(ss_dist.shape, 0, T))
# 	last_dist = ss_dist.copy()
# 	Pi_full = np.kron(Pi, Pi_dlogB)
# 	mass_moved_by_n_A = np.zeros((n_A, T))
# 	weighted_avg_price_change = np.zeros((n_A, T))
# 	for t in range(T):
# 		if t > 0:
# 			last_dist = path_dist_remaining[t-1]
# 		new_dist_pold = np.reshape(np.reshape(last_dist, (n_dlogp, n_A*n_dlogB), order='C') @ Pi_full, (n_dlogp, n_A, n_dlogB), order='C')
# 		for a in range(n_A):
# 			for b in range(n_dlogB):
# 				p_visit = np.where(new_dist_pold[:,a,b]>0)[0]
# 				for p in p_visit:
# 					if t < BURN_PERIOD:
# 						if policy_ind_ss[p,a,b]==p:
# 							path_dist_remaining[t][p,a,b] += new_dist_pold[p,a,b]
# 						# else:
# 						# 	mass_moved_by_n_A[a,t] += new_dist_pold[p,a,b]
# 						# 	weighted_avg_price_change[a,t] += new_dist_pold[p,a,b]*(np.abs(dlogp_grid[p_visit] - dlogp_grid[p]))
# 					else:
# 						if policy_ind_path[t - BURN_PERIOD][p,a,b]==p:
# 							path_dist_remaining[t][p,a,b] += new_dist_pold[p,a,b]
# 						# else:
# 						# 	mass_moved_by_n_A[a,t] += new_dist_pold[p,a,b]
# 						# 	weighted_avg_price_change[a,t] += new_dist_pold[p,a,b]*(np.abs(dlogp_grid[policy_ind_path[t - BURN_PERIOD][p,a,b]] - dlogp_grid[p]))
# 		for a in range(n_A):
# 			for b in range(n_dlogB):
# 				new_index_visit = np.where(Pi_full[a + b*n_A]>0)[0]
# 				print(new_index_visit)
# 				for new_index in new_index_visit:
# 					new_a = new_index % n_A
# 					new_b = int(new_index / n_A)
# 					transition_prob = Pi_full[a + b*n_A, new_index]
# 					print(new_index, new_a, new_b, Pi_full[a + b*n_A, new_index])
# 					p_visit = np.where(new_dist_pold[:,new_a,new_b]>0)[0]
# 					for p in p_visit:
# 						if t < BURN_PERIOD:
# 							if policy_ind_ss[p,new_a,new_b]!=p:
# 								mass_moved_by_n_A[a,t] += last_dist[p,a,b]*transition_prob
# 								weighted_avg_price_change[a,t] += last_dist[p,a,b]*transition_prob*(np.abs(dlogp_grid[policy_ind_ss[t - BURN_PERIOD][p,new_a,new_b]] - dlogp_grid[p]))
# 						else:
# 							if policy_ind_path[t - BURN_PERIOD][p,new_a,new_b]!=p:
# 								mass_moved_by_n_A[a,t] += last_dist[p,a,b]
# 								weighted_avg_price_change[a,t] += last_dist[p,a,b]*transition_prob*(np.abs(dlogp_grid[policy_ind_path[t - BURN_PERIOD][p,new_a,new_b]] - dlogp_grid[p]))
# 	N_GROUPS = 5
# 	cutoffs = np.array([int(i) for i in np.linspace(0, n_A, N_GROUPS+1)])
# 	if MASK:
# 		cutoffs = np.array([int(i) for i in np.linspace(MASK_LO, MASK_HI, N_GROUPS+1)])
# 	frac_change_arr = np.zeros((N_GROUPS,T))
# 	price_change_size_arr = np.zeros((N_GROUPS,T))
# 	for i in range(N_GROUPS):
# 		#print(i, 1 - path_dist_remaining[:,:,cutoffs[i]:cutoffs[i+1],:].sum(axis=3).sum(axis=2).sum(axis=1) / ss_dist[:,cutoffs[i]:cutoffs[i+1],:].sum())
# 		frac_change_arr[i,:] = 1 - path_dist_remaining[:,:,cutoffs[i]:cutoffs[i+1],:].sum(axis=3).sum(axis=2).sum(axis=1) / ss_dist[:,cutoffs[i]:cutoffs[i+1],:].sum()
# 		for t in range(T):
# 			price_change_size_arr[i,t] = sum(sum(weighted_avg_price_change[cutoffs[i]:cutoffs[i+1], :t+1]))/sum(sum(mass_moved_by_n_A[cutoffs[i]:cutoffs[i+1], :t+1]))
# 		return frac_change_arr

def calculate_frac_change_quintile(ss_dist, policy_ind_path, policy_ind_ss, T, BURN_PERIOD):
	"""Print LaTeX tables of price-change statistics for five productivity groups.

	The productivity index range (MASK_LO..MASK_HI when MASK is on, otherwise
	0..n_A) is cut into five consecutive groups. For each group the
	cumulative fraction of changers and the average signed and absolute
	price change come from ``calculate_price_change_by_initial_prod``.
	Nothing is returned.
	"""
	N_GROUPS = 5
	lower,upper = (MASK_LO, MASK_HI) if MASK else (0, n_A)
	cutoffs = np.array([int(edge) for edge in np.linspace(lower, upper, N_GROUPS+1)])
	frac_change_arr = np.zeros((N_GROUPS,T))
	price_change_size_arr = np.zeros((N_GROUPS,T))
	price_change_size_abs_arr = np.zeros((N_GROUPS,T))
	for group,(first,stop) in enumerate(zip(cutoffs[:-1], cutoffs[1:])):
		members = list(range(first, stop))
		group_result = calculate_price_change_by_initial_prod(ss_dist, policy_ind_path, policy_ind_ss, T, BURN_PERIOD, members)
		frac_change_arr[group,:],price_change_size_arr[group,:],price_change_size_abs_arr[group,:] = group_result
	# One table per statistic; rows are groups, columns are periods.
	tables = (
		('Fraction change: ', frac_change_arr, '%0.3f'),
		('Price change size: ', price_change_size_arr, '%0.4f'),
		('Price change size abs: ', price_change_size_abs_arr, '%0.4f'),
		('Realized pass-through: ', frac_change_arr * price_change_size_arr, '%0.4f'),
	)
	for heading,values,number_format in tables:
		print(heading, pd.DataFrame(values).to_latex(float_format=number_format))

def calculate_price_change_by_initial_prod(ss_dist, policy_ind_path, policy_ind_ss, T, BURN_PERIOD, a_index_choose):
	"""Track first price changes of firms that start in chosen productivity states.

	The initial distribution is ``ss_dist`` restricted to the productivity
	indices in ``a_index_choose`` and rescaled to a total mass of
	``INITIAL_MASS``. Each period the exogenous (A, B) transition is applied;
	mass whose policy index equals its current price index stays in the
	tracked population, mass that adjusts is recorded and dropped. Periods
	before ``BURN_PERIOD`` use ``policy_ind_ss``, later ones
	``policy_ind_path[t - BURN_PERIOD]``.

	Returns three length-``T`` arrays, each cumulative up to and including
	period t: the fraction of the initial mass that has adjusted, the
	mass-weighted average signed log price change of adjusters, and the
	mass-weighted average absolute change. The averages are NaN while no
	mass has adjusted yet.
	"""
	path_dist_remaining = np.zeros(np.insert(ss_dist.shape, 0, T))
	last_dist = ss_dist.copy()
	# Zero out every productivity state that was not selected.
	selected = np.isin(np.arange(n_A), a_index_choose)
	last_dist[:,~selected,:] = 0
	INITIAL_MASS = 100000
	last_dist = INITIAL_MASS / last_dist.sum() * last_dist
	Pi_full = np.kron(Pi, Pi_dlogB)
	mass_moved_by_n_A = np.zeros(T)
	weighted_avg_price_change = np.zeros(T)
	weighted_avg_price_change_abs = np.zeros(T)
	for t in range(T):
		if t > 0:
			last_dist = path_dist_remaining[t-1]
		new_dist_pold = np.reshape(np.reshape(last_dist, (n_dlogp, n_A*n_dlogB), order='C') @ Pi_full, (n_dlogp, n_A, n_dlogB), order='C')
		policy_t = policy_ind_ss if t < BURN_PERIOD else policy_ind_path[t - BURN_PERIOD]
		for a in range(n_A):
			for b in range(n_dlogB):
				for p_ind in np.where(new_dist_pold[:,a,b]>0)[0]:
					mass = new_dist_pold[p_ind,a,b]
					target = policy_t[p_ind,a,b]
					if target==p_ind:
						path_dist_remaining[t][p_ind,a,b] += mass
					else:
						# Size of the jump from the current to the chosen price.
						# The burn-in branch used to index the grid with the
						# whole array of occupied price indices here instead
						# of the policy target.
						jump = dlogp_grid[target] - dlogp_grid[p_ind]
						mass_moved_by_n_A[t] += mass
						weighted_avg_price_change_abs[t] += mass*(np.abs(jump))
						weighted_avg_price_change[t] += mass*(jump)
	frac_change_arr = np.zeros(T)
	price_change_size_abs_arr = np.zeros(T)
	price_change_size_arr = np.zeros(T)
	for t in range(T):
		moved_so_far = mass_moved_by_n_A[:t+1].sum()
		frac_change_arr[t] = moved_so_far / INITIAL_MASS
		price_change_size_abs_arr[t] = weighted_avg_price_change_abs[:t+1].sum() / moved_so_far
		price_change_size_arr[t] = weighted_avg_price_change[:t+1].sum() / moved_so_far
	return frac_change_arr,price_change_size_arr,price_change_size_abs_arr

def run_simulation_with_paths_dist(policy_paths, policy_ss, ss_dist, P_ss, w_ss, Y_ss, params, dlogP_path, dlogw_path, dlogY_path, ss_dist_end=None, final_shock=0, saveprefix=None):
	"""Simulate the distribution forward under a policy path and report aggregates.

	Parameters
	----------
	policy_paths : sequence of numpy.ndarray
		Chosen log price per state for every shock period; its length sets
		the simulation horizon.
	policy_ss : numpy.ndarray
		Steady-state policy, used for periods before ``BURN_PERIOD`` (which
		is fixed at 0 here).
	ss_dist : numpy.ndarray
		Initial distribution over (price, productivity, demand) states.
	P_ss, w_ss, Y_ss : float
		Steady-state price index, wage and output.
	params : sequence
		Six steady-state values in the order returned by
		``calculate_other_vars_dist``.
	dlogP_path, dlogY_path
		Accepted for interface compatibility; not used.
	dlogw_path : sequence of float
		Log wage deviations per shock period.
	ss_dist_end : numpy.ndarray, optional
		Distribution used to compute the reference small/large markups.
		Defaults to a copy of ``ss_dist``.
	final_shock : float
		Log shift applied to both ``P_ss`` and ``w_ss`` for that reference.
	saveprefix : str, optional
		Prefix handed to ``plot_mu_over_path`` for saving figures.

	Returns
	-------
	tuple of numpy.ndarray
		Per-period log deviations of P, Y, average markup, L, P_Y, small and
		large group markups, followed by the cumulative fraction of price
		changers overall, for small and for large firms.

	Also prints the implied productivity deviation and the quintile tables.
	"""
	P_Y_ss,L_ss,Aagg_ss,mu_bar_ss,alpha_ss,Ups_C_ss = params
	# Resolve the default before ss_dist_end is first used below.
	if ss_dist_end is None:
		ss_dist_end = ss_dist.copy()
	BURN_PERIOD = 0
	# Horizon taken from the policy path that is passed in (previously read
	# from a global `value_fn_paths`; assumed to have the same length).
	SHOCK_PERIOD = len(policy_paths)
	T = BURN_PERIOD + SHOCK_PERIOD
	path_dist = np.zeros(np.insert(ss_dist.shape, 0, T))
	last_dist = ss_dist
	Pi_full = np.kron(Pi, Pi_dlogB)
	policy_ind_ss = np.searchsorted(dlogp_grid, policy_ss)
	policy_ind_path = np.searchsorted(dlogp_grid, policy_paths)
	# Index arrays that broadcast against a policy-index array.
	a_index = np.arange(n_A)[None,:,None]
	b_index = np.arange(n_dlogB)[None,None,:]
	for t in range(T):
		if t > 0:
			last_dist = path_dist[t-1]
		# Exogenous (A, B) transition with prices held fixed ...
		new_dist_pold = np.reshape(np.reshape(last_dist, (n_dlogp, n_A*n_dlogB), order='C') @ Pi_full, (n_dlogp, n_A, n_dlogB), order='C')
		policy_t = policy_ind_ss if t < BURN_PERIOD else policy_ind_path[t - BURN_PERIOD]
		# ... then move every cell with positive mass to its chosen price.
		np.add.at(path_dist[t], (policy_t, a_index, b_index), np.where(new_dist_pold > 0, new_dist_pold, 0.0))
	# Path of wages (given)
	dlogw_generated = np.append(np.zeros(BURN_PERIOD), dlogw_path)
	w_generated = np.exp(np.log(w_ss) + dlogw_generated)
	# Firm variables
	P_generated = solve_for_P_dist_path(path_dist, P_ss, Ups_C_ss)
	mu_bar_generated,Y_generated,L_generated,P_Y_generated,mu_small_generated,mu_large_generated = calculate_other_path_vars(path_dist, P_generated, w_generated, params)
	# Reference group markups at the end-point distribution, with P and w
	# both shifted by final_shock.
	_,_,_,_,mu_small_ss,mu_large_ss = calculate_other_path_vars_t(ss_dist_end, np.exp(np.log(P_ss)+final_shock), np.exp(np.log(w_ss)+final_shock), params)
	print('dlogA: ', np.log(Y_generated / L_generated) - np.log(Aagg_ss))
	dlogP = np.log(P_generated) - np.log(P_ss)
	dlogY = np.log(Y_generated) - np.log(Y_ss)
	dlogmu_bar = np.log(mu_bar_generated) - np.log(mu_bar_ss)
	dlogL = np.log(L_generated) - np.log(L_ss)
	dlogP_Y = np.log(P_Y_generated) - np.log(P_Y_ss)
	dlogmu_small = np.log(mu_small_generated) - np.log(mu_small_ss)
	dlogmu_large = np.log(mu_large_generated) - np.log(mu_large_ss)
	frac_change,frac_change_small,frac_change_large = calculate_frac_change(ss_dist, policy_ind_path, policy_ind_ss, T, BURN_PERIOD)
	calculate_frac_change_quintile(ss_dist, policy_ind_path, policy_ind_ss, T, BURN_PERIOD)
	plot_mu_over_path(path_dist, w_generated, saveprefix)
	return dlogP[BURN_PERIOD:],dlogY[BURN_PERIOD:],dlogmu_bar[BURN_PERIOD:],dlogL[BURN_PERIOD:],dlogP_Y[BURN_PERIOD:],dlogmu_small[BURN_PERIOD:],dlogmu_large[BURN_PERIOD:],frac_change[BURN_PERIOD:],frac_change_small[BURN_PERIOD:],frac_change_large[BURN_PERIOD:]

def plot_results(dlogw_path, res, final_shock=0, color='blue', rescale=1, savename=None, tmax=15):
	"""Plot the first `tmax` periods of each aggregate response on a two-column grid.

	`res` must unpack into exactly ten paths, in this order: dlogP, dlogY,
	dlogmu_bar, dlogL, dlogP_Y, dlogmu_small, dlogmu_large, frac_change,
	frac_change_small, frac_change_large. Only the first five are plotted,
	together with `dlogw_path` and two derived series. Every series is
	multiplied by `rescale` before plotting. `final_shock` is accepted but
	not used.

	When `savename` is given, the figure is written there and closed;
	otherwise it is left open.
	"""
	(dlogP_path, dlogY_path, dlogmu_bar_path, dlogL_path, dlogP_Y_path,
		dlogmu_small, dlogmu_large,
		frac_change, frac_change_small, frac_change_large) = res

	# Derived series.
	dlogA_path = dlogY_path - dlogL_path
	dlogEmu_path = dlogP_Y_path - dlogw_path  # computed but not currently plotted
	dlogM_path = dlogP_Y_path + dlogY_path

	panels = [
		('dlogw', dlogw_path),
		('dlogM', dlogM_path),
		('dlogP', dlogP_path),
		('dlogP_Y', dlogP_Y_path),
		('dlogY', dlogY_path),
		('dlogL', dlogL_path),
		('dlogA', dlogA_path),
		('dlogmu_bar', dlogmu_bar_path),
	]

	# Two panels per row, rounding up.
	n_rows = (len(panels) + 1) // 2
	fig, axs = plt.subplots(n_rows, 2, figsize=(8, 1.5 * n_rows))
	for ax, (title, series) in zip(axs.flatten(), panels):
		ax.plot(series[:tmax] * rescale, marker='o', color=color, fillstyle='none')
		ax.set_title(title)
	plt.tight_layout(rect=[0, 0, 1, 0.95])

	if savename is None:
		return
	fig.savefig(savename, dpi=300, bbox_inches='tight')
	plt.close(fig)

def plot_value_fn_forind(ind, value_fn, value_fn_paths, ss_dist, savename=None):
	"""Plot value functions and the steady-state price density for one productivity index.

	Left panel: `value_fn` and `value_fn_paths[0]` along the price grid,
	both sliced at productivity index `ind` and at the position where 0
	would be inserted into `dlogB_grid`. Right panel: `ss_dist` at index
	`ind`, summed over its last axis and normalised by its total.

	When `savename` is given, the figure is written there and closed.
	"""
	fig, (ax_value, ax_density) = plt.subplots(1, 2, figsize=(8, 4))

	b_zero = np.searchsorted(dlogB_grid, 0)
	ax_value.plot(dlogp_grid, value_fn[:, ind, b_zero], label='Steady state')
	ax_value.plot(dlogp_grid, value_fn_paths[0][:, ind, np.searchsorted(dlogB_grid, 0)], label='$t=0$')
	ax_value.legend()
	ax_value.set_title('Value functions')

	ax_density.plot(dlogp_grid, ss_dist[:, ind, :].sum(axis=1) / ss_dist[:, ind, :].sum())
	ax_density.set_title('Steady-state density')

	plt.tight_layout()
	if savename is None:
		return
	fig.savefig(savename, dpi=300, bbox_inches='tight')
	plt.close(fig)

def plot_value_fns(value_fn, value_fn_paths, ss_dist, saveprefix=None):
	"""Call `plot_value_fn_forind` at a fixed set of percentiles of the productivity grid.

	Each percentile is mapped to an index in 0..n_A-1. With `saveprefix`
	set, each figure is saved as `<saveprefix>_valuefn_firm<index>.png`;
	otherwise no file name is passed on.
	"""
	firm_inds = [int(pct / 100 * (n_A - 1)) for pct in (0, 10, 20, 50, 80, 90, 100)]
	for firm_ind in firm_inds:
		if saveprefix is not None:
			target = saveprefix + '_valuefn_firm' + str(firm_ind) + '.png'
		else:
			target = None
		plot_value_fn_forind(firm_ind, value_fn, value_fn_paths, ss_dist, savename=target)

def plot_ss_mu_dist(ss_dist, w_fordist, savename=None):
	"""Plot the distribution-weighted average markup at each productivity index.

	The markup at each grid point is price * A / w, with price taken as
	exp(dlogp_mesh), A from `A_mesh` and w = `w_fordist`. The average is
	weighted by `ss_dist` and taken over axes 2 and 0. All open figures
	are closed first. When `savename` is given, the figure is written to
	`<savename>.png`; it is not closed afterwards.
	"""
	plt.close('all')

	price_on_grid = np.exp(np.log(1) + dlogp_mesh)
	markup_on_grid = price_on_grid * A_mesh.copy() / w_fordist

	weighted_markup = (ss_dist * markup_on_grid).sum(axis=2).sum(axis=0)
	mass_by_firm = ss_dist.sum(axis=2).sum(axis=0)

	plt.plot(np.linspace(0, 1, n_A), weighted_markup / mass_by_firm)
	plt.xlabel('Firm productivity space')
	plt.ylabel('Steady-state avg. markup')

	fig = plt.gcf()
	if savename is not None:
		fig.savefig(savename + '.png', dpi=300, bbox_inches='tight')

def calculate_prod_shock_sd():
	"""Return (and print) the standard deviation of log-productivity changes.

	Every pair (i, j) of productivity grid points contributes the change
	log A_j - log A_i with weight Pi[i, j] * wt[j]. The result is the
	weighted standard deviation of those changes around their weighted mean.

	The moments are computed directly on the flattened arrays. Grouping
	equal changes first (as a pandas value-count table would) yields the
	same weighted moments, so the grouping step is unnecessary; skipping
	it also avoids depending on the column name pandas assigns to
	`value_counts` results, which differs between pandas 1.x and 2.x.

	Reads the module-level `A_grid`, `Pi` and `wt`.
	"""
	log_A = np.log(A_grid)
	# prod_change[i, j] = log A_j - log A_i
	prod_change = log_A[np.newaxis, :] - log_A[:, np.newaxis]
	# density_change[i, j] = Pi[i, j] * wt[j]
	density_change = Pi @ np.diag(wt)

	changes = prod_change.ravel()
	weights = np.asarray(density_change).ravel()

	mean_change = np.average(changes, weights=weights)
	prod_shock_var = np.average((changes - mean_change) ** 2, weights=weights)
	prod_shock_sd = np.sqrt(prod_shock_var)
	print('Standard deviation of productivity shocks: ', prod_shock_sd)
	return prod_shock_sd

def compare_prod_dist_to_normal():
	"""Plot the middle column of `Pi` against a discretised normal density.

	The normal has mean 0 and the standard deviation returned by
	`calculate_prod_shock_sd`. Its mass per grid cell is the difference of
	its CDF between consecutive midpoints of the grid, after the grid is
	recentred on its middle point. The x-axis is limited to two grid
	points beyond the positive entries of the plotted column.
	"""
	mid = int(n_A / 2)
	centred_grid = dlogA_grid - dlogA_grid[mid]
	cell_edges = (centred_grid[:-1] + centred_grid[1:]) / 2
	transition_mid = Pi[:, mid]

	stdev = calculate_prod_shock_sd()
	# CDF differences between neighbouring midpoints give the mass on each interior grid point.
	norm_pdf = np.diff(norm.cdf(cell_edges, 0, stdev))

	plt.plot(centred_grid, transition_mid, label='Fat-tailed shocks')
	plt.plot(centred_grid[1:-1], norm_pdf, ls='--', color='black', label='Normal')
	plt.legend()

	support = np.where(transition_mid > 0)[0]
	plt.xlim(centred_grid[support.min() - 2], centred_grid[support.max() + 2])
	plt.xlabel('$d \\log A$')
	plt.ylabel('Density')

def plot_price_changes(policy_ss, ss_dist, saveprefix=None):
	"""Plot the steady-state distribution of non-zero log price changes.

	Strictly positive mass in `ss_dist[i, a, b]` is moved to next-period
	states (a', b') with probability `Pi[a, a'] * Pi_dlogB[b, b']`. A firm
	arriving at (i, a', b') changes its log price by
	`policy_ss[i, a', b'] - dlogp_grid[i]`. Mass is then summed by the
	size of that change.

	Prints the share of mass with a non-zero change and the 25/50/75th
	percentiles of the absolute non-zero change, and draws a bar chart of
	the binned non-zero changes on the current axes. With `saveprefix`
	set, the chart is written to `<saveprefix>.png` and `<saveprefix>.tex`
	(via tikzplotlib) and the figure is closed.

	Unlike a lookup into a table of grid-to-grid differences, this does
	not require `policy_ss` to lie exactly on `dlogp_grid`.

	Returns the share of mass that changes price.
	"""
	# Next-period mass by (price index, a', b'); non-positive entries of ss_dist are ignored.
	mass_now = np.where(ss_dist > 0, ss_dist, 0.0)
	mass_next = np.einsum('iab,ac,bd->icd', mass_now, Pi, Pi_dlogB)
	# Log price change chosen in each next-period state.
	dlogp_chosen = policy_ss - np.asarray(dlogp_grid)[:, np.newaxis, np.newaxis]
	price_change_density = pd.Series(mass_next.ravel()).groupby(dlogp_chosen.ravel()).sum()

	unchanged_mass = price_change_density[price_change_density.index == 0].sum()
	percent_changeprice = 1 - unchanged_mass / price_change_density.sum()
	print('Percent of firms changing price: ', percent_changeprice)

	# Explicit column names keep the code independent of the pandas version.
	price_change_density = price_change_density[price_change_density > 0].rename_axis('dlogp').reset_index(name='mass')
	# Work on a copy so that adding columns does not write into a slice of the frame above.
	price_change_density_nonzero = price_change_density.loc[price_change_density['dlogp'] != 0].copy()
	if price_change_density_nonzero.empty:
		print('No non-zero price changes to plot')
		return percent_changeprice

	smallest = price_change_density_nonzero['dlogp'].min()
	largest = price_change_density_nonzero['dlogp'].max()
	raw_increment = (largest - smallest) / 30
	bins_increment = np.round(raw_increment, 3)
	if bins_increment <= 0:
		# Rounding collapsed the width (very narrow range); use the unrounded
		# width, or a nominal one when all changes are equal.
		bins_increment = raw_increment if raw_increment > 0 else 1e-3
	# Build edges from integer multiples of the width and include the top edge:
	# np.arange(start, stop, step) stops short of `stop`, which would leave the
	# largest changes outside every bin.
	first_multiple = int(np.floor(smallest / bins_increment))
	last_multiple = max(int(np.ceil(largest / bins_increment)), first_multiple + 1)
	bins = bins_increment * np.arange(first_multiple, last_multiple + 1)
	# include_lowest keeps a change sitting exactly on the first edge.
	price_change_density_nonzero['bin'] = pd.cut(price_change_density_nonzero['dlogp'], bins, include_lowest=True).apply(lambda x: x.right)
	# observed=False keeps empty bins on the axis.
	binned_price_change_density = price_change_density_nonzero.groupby('bin', observed=False)['mass'].sum() / price_change_density_nonzero['mass'].sum()
	binned_price_change_density.plot(kind='bar', width=1.0)
	plt.xlabel('$d\\log p$')
	plt.ylabel('Density')
	if saveprefix is not None:
		# Imported here because it is only needed for the optional .tex export.
		import tikzplotlib
		plt.tight_layout()
		plt.gcf().savefig(saveprefix + '.png', dpi=300, bbox_inches='tight')
		tikzplotlib.clean_figure()
		tikzplotlib.save(saveprefix + '.tex', extra_axis_parameters=['PlotStyle'])
		plt.close(plt.gcf())

	# Cumulative distribution of absolute non-zero changes, inverted at the quartiles.
	price_change_density_nonzero['abs'] = np.abs(price_change_density_nonzero['dlogp'])
	abs_change_density = price_change_density_nonzero.groupby(['abs'])['mass'].sum() / price_change_density_nonzero['mass'].sum()
	abs_change_density = abs_change_density.cumsum()
	print('Moments of price change dist: ', interp1d(abs_change_density, abs_change_density.index)([0.25, 0.5, 0.75]))
	return percent_changeprice

def get_FPA_dist(ss_dist, policy_ss):
	"""Compute the frequency of price adjustment, overall and by productivity index.

	The distribution is pushed one step through `kron(Pi, Pi_dlogB)` on
	its flattened last two axes. Mass landing where `policy_ss` equals
	`dlogp_mesh` counts as keeping its price. The adjustment frequency is
	one minus the kept share of the starting mass. The per-index figures
	repeat this starting from the mass at a single productivity index.

	Prints the overall frequency and the mean of the per-index
	frequencies over MASK_LO:MASK_HI.

	Returns (fpa_all, fpa_arr), where fpa_arr has length n_A.
	"""
	Pi_full = np.kron(Pi, Pi_dlogB)
	keeps_price = (policy_ss == dlogp_mesh).astype(int)

	def kept_mass_after_shock(dist):
		# Apply the joint transition to the flattened last two axes, then keep non-adjusters.
		flat = np.reshape(dist, (n_dlogp, n_A * n_dlogB), order='C')
		moved = np.reshape(flat @ Pi_full, (n_dlogp, n_A, n_dlogB), order='C')
		return (moved * keeps_price).sum()

	fpa_all = 1 - (kept_mass_after_shock(ss_dist) / ss_dist.sum())

	fpa_arr = np.zeros(n_A)
	for firm_ind in range(n_A):
		# Start from the mass at this productivity index only.
		single_firm_dist = np.zeros(ss_dist.shape)
		single_firm_dist[:, firm_ind, :] = ss_dist[:, firm_ind, :]
		fpa_arr[firm_ind] = 1 - kept_mass_after_shock(single_firm_dist) / single_firm_dist.sum()

	print('Percent of all firms changing price: ', fpa_all)
	print('Percent of unmasked firms changing price: ', fpa_arr[MASK_LO:MASK_HI].mean())
	return fpa_all, fpa_arr

def plot_FPA_dist(ss_dist, policy_ss, savename=None):
	"""Plot the per-index frequency of price adjustment from `get_FPA_dist`.

	All open figures are closed first. When `savename` is given, the
	figure is written to `<savename>.png` and closed.
	"""
	plt.close('all')
	_, fpa_by_firm = get_FPA_dist(ss_dist, policy_ss)

	firm_axis = np.linspace(0, 1, n_A)
	plt.plot(firm_axis, fpa_by_firm)
	plt.xlabel('Firm productivity space')
	plt.ylabel('Frequency of price adjustment')

	fig = plt.gcf()
	if savename is None:
		return
	fig.savefig(savename + '.png', dpi=300, bbox_inches='tight')
	plt.close(fig)



# Steady-state levels passed to the solvers below.
P_ss = 1
w_ss = 1
Y_ss = 1

# Length (in periods) of the wage path fed to the transition solver.
SHOCK_PERIOD = 50

# Experiment grid: one run per (menu cost, inverse Frisch, shock size) combination.
menu_cost_arr = [0.02]
inv_frisch_arr = [5, 0]
final_shock_arr = [0.04]
ind_experiment = 0

n_experiments = len(menu_cost_arr) * len(inv_frisch_arr) * len(final_shock_arr)
prod_sd_vals = np.zeros(n_experiments)
pricechange_vals = np.zeros(n_experiments)

Pi = Pi_dlogA_symmetric

for i_m,m in enumerate(menu_cost_arr):
	for i_f,f in enumerate(inv_frisch_arr):
		for i_s,s in enumerate(final_shock_arr):
			# Overwrite the module-level parameters for this experiment.
			menu_cost = m
			inv_frisch = f
			final_shock = s
			# Wage path: constant at |final_shock| for SHOCK_PERIOD periods.
			dlogw_path_hit_pos = np.ones(SHOCK_PERIOD) * np.abs(final_shock)
			dlogw_path_arr = [dlogw_path_hit_pos]
			path_labels = ['hit_pos']
			folder_name = 'menu_cost_wage_paths/Kimball_symmask_v37_menu' + str(int(menu_cost*100)) + '_invfrisch' + str(int(inv_frisch)) + '_shocksize' + str(int(final_shock * 100))
			# makedirs also creates the parent folder and tolerates an existing one.
			os.makedirs(folder_name, exist_ok=True)
			# Solve for steady-state policy,value,distribution
			policy_ss,value_fn,ss_dist,params = generate_ss_values_dist(P_ss, w_ss, Y_ss)
			plot_ss_mu_dist(ss_dist, w_ss, savename=folder_name + '/markups_by_firmindex')
			plot_FPA_dist(ss_dist, policy_ss, savename=folder_name + '/fpa_by_firmindex')
			# plot_price_changes names its output argument `saveprefix`, not `savename`.
			plot_price_changes(policy_ss, ss_dist, saveprefix=folder_name + '/price_change_dist')
			ind_experiment += 1
			# Terminal steady state with P and w both scaled up by exp(|final_shock|).
			policy_end_pos,value_end_pos,ss_dist_end_pos,params_end_pos = generate_ss_values_dist(np.exp(np.log(P_ss)+np.abs(final_shock)), np.exp(np.log(w_ss)+np.abs(final_shock)), Y_ss)
			for dlogw_path,w_label in zip(dlogw_path_arr, path_labels):
				NUM_PERIODS = len(dlogw_path)
				# Row i holds the guess used in iteration i; row i+1 receives the path it implies.
				dlogP_path_t = np.zeros((MIT_shock_iter_max, NUM_PERIODS))
				dlogY_path_t = np.zeros((MIT_shock_iter_max, NUM_PERIODS))
				for i in range(MIT_shock_iter_max-1):
					print(i, ' -- P: ', dlogP_path_t[i][:6])
					print(i, ' -- Y: ', dlogY_path_t[i][:6])
					dlogP_path_try = dlogP_path_t[i]
					dlogY_path_try = dlogY_path_t[i]
					if ('perm' in w_label) or ('hit' in w_label):
						if 'pos' in w_label:
							value_fn_paths,policy_paths = backward_iterate_valuefn(value_end_pos, policy_end_pos, dlogP_path_try, dlogw_path, dlogY_path_try, P_ss, w_ss, Y_ss)
							res = run_simulation_with_paths_dist(policy_paths, policy_ss, ss_dist, P_ss, w_ss, Y_ss, params, dlogP_path_try, dlogw_path, dlogY_path_try, ss_dist_end=ss_dist_end_pos, final_shock=np.abs(final_shock), saveprefix=folder_name + '/' + w_label + '_iter' + str(i))
						else:
							# NOTE(review): value_end_neg / policy_end_neg / ss_dist_end_neg are not
							# computed in this section; this branch is unreachable with the current
							# labels and needs a negative-shock terminal steady state before use.
							value_fn_paths,policy_paths = backward_iterate_valuefn(value_end_neg, policy_end_neg, dlogP_path_try, dlogw_path, dlogY_path_try, P_ss, w_ss, Y_ss)
							res = run_simulation_with_paths_dist(policy_paths, policy_ss, ss_dist, P_ss, w_ss, Y_ss, params, dlogP_path_try, dlogw_path, dlogY_path_try, ss_dist_end=ss_dist_end_neg, final_shock=-np.abs(final_shock))
					else:
						value_fn_paths,policy_paths = backward_iterate_valuefn(value_fn, policy_ss, dlogP_path_try, dlogw_path, dlogY_path_try, P_ss, w_ss, Y_ss)
						res = run_simulation_with_paths_dist(policy_paths, policy_ss, ss_dist, P_ss, w_ss, Y_ss, params, dlogP_path_try, dlogw_path, dlogY_path_try)
					dlogP_path_t[i+1],dlogY_path_t[i+1],_,_,_,_,_,_,_,_ = res
					# Overwritten every iteration, so the file ends up holding the last one.
					np.save(folder_name + '/final_res', np.array(res))
					cov_results = np.array(calculation_markup_passthrough_cov(ss_dist, policy_paths, P_ss, Y_ss, w_ss, params))
					np.save(folder_name + '/cov_breakdown_' + str(i), cov_results)
					if 'rebound' not in w_label:
						plot_results(dlogw_path, res, color='blue', rescale=100, tmax=len(dlogw_path), savename=folder_name + '/' + w_label + '_iter' + str(i) + '.png')
						plot_results(dlogw_path, res, color='blue', rescale=100, tmax=15, savename=folder_name + '/' + w_label + '_iter' + str(i) + '_15.png')
						#plot_value_fns(value_fn, value_fn_paths, ss_dist, saveprefix=folder_name + '/' + w_label + '_iter' + str(i))
					else:
						res_trim = [r[SHOCK_PERIOD:] for r in res]
						plot_results(dlogw_path[SHOCK_PERIOD:], res_trim, color='blue', rescale=100, tmax=len(dlogw_path[SHOCK_PERIOD:]), savename=folder_name + '/' + w_label + '_iter' + str(i) + '.png')
						plot_results(dlogw_path[SHOCK_PERIOD:], res_trim, color='blue', rescale=100, tmax=15, savename=folder_name + '/' + w_label + '_iter' + str(i) + '_15.png')
						#plot_value_fns(value_fn, value_fn_paths[SHOCK_PERIOD:], ss_dist, saveprefix=folder_name + '/' + w_label + '_iter' + str(i))
					# Converged when neither path moved by more than the threshold in absolute
					# value; a signed maximum would also accept large downward revisions.
					path_change = max(np.abs(dlogP_path_t[i+1] - dlogP_path_t[i]).max(), np.abs(dlogY_path_t[i+1] - dlogY_path_t[i]).max())
					if path_change < MIT_iter_thresh:
						print('Converged!')
						break


# Re-solve the first experiment of the grid without writing any plots or
# result files, leaving its steady state and transition paths in the module
# namespace.
menu_cost = menu_cost_arr[0]
inv_frisch = inv_frisch_arr[0]
# Reset shock path
final_shock = final_shock_arr[0]
dlogw_path_hit_pos = np.ones(SHOCK_PERIOD) * np.abs(final_shock)
dlogw_path_arr = [dlogw_path_hit_pos]
path_labels = ['hit_pos']
# Solve for steady-state policy,value,distribution
policy_ss,value_fn,ss_dist,params = generate_ss_values_dist(P_ss, w_ss, Y_ss)
# Terminal steady state with P and w both scaled up by exp(|final_shock|).
policy_end_pos,value_end_pos,ss_dist_end_pos,params_end_pos = generate_ss_values_dist(np.exp(np.log(P_ss)+np.abs(final_shock)), np.exp(np.log(w_ss)+np.abs(final_shock)), Y_ss)
for dlogw_path,w_label in zip(dlogw_path_arr, path_labels):
	NUM_PERIODS = len(dlogw_path)
	# Row i holds the guess used in iteration i; row i+1 receives the path it implies.
	dlogP_path_t = np.zeros((MIT_shock_iter_max, NUM_PERIODS))
	dlogY_path_t = np.zeros((MIT_shock_iter_max, NUM_PERIODS))
	for i in range(MIT_shock_iter_max-1):
		print(i, ' -- P: ', dlogP_path_t[i][:6])
		print(i, ' -- Y: ', dlogY_path_t[i][:6])
		dlogP_path_try = dlogP_path_t[i]
		dlogY_path_try = dlogY_path_t[i]
		if ('perm' in w_label) or ('hit' in w_label):
			if 'pos' in w_label:
				value_fn_paths,policy_paths = backward_iterate_valuefn(value_end_pos, policy_end_pos, dlogP_path_try, dlogw_path, dlogY_path_try, P_ss, w_ss, Y_ss)
				res = run_simulation_with_paths_dist(policy_paths, policy_ss, ss_dist, P_ss, w_ss, Y_ss, params, dlogP_path_try, dlogw_path, dlogY_path_try, ss_dist_end=ss_dist_end_pos, final_shock=np.abs(final_shock), saveprefix=None)
			else:
				# NOTE(review): value_end_neg / policy_end_neg / ss_dist_end_neg are not
				# computed in this section; this branch is unreachable with the current
				# labels and needs a negative-shock terminal steady state before use.
				value_fn_paths,policy_paths = backward_iterate_valuefn(value_end_neg, policy_end_neg, dlogP_path_try, dlogw_path, dlogY_path_try, P_ss, w_ss, Y_ss)
				res = run_simulation_with_paths_dist(policy_paths, policy_ss, ss_dist, P_ss, w_ss, Y_ss, params, dlogP_path_try, dlogw_path, dlogY_path_try, ss_dist_end=ss_dist_end_neg, final_shock=-np.abs(final_shock), saveprefix=None)
		else:
			value_fn_paths,policy_paths = backward_iterate_valuefn(value_fn, policy_ss, dlogP_path_try, dlogw_path, dlogY_path_try, P_ss, w_ss, Y_ss)
			res = run_simulation_with_paths_dist(policy_paths, policy_ss, ss_dist, P_ss, w_ss, Y_ss, params, dlogP_path_try, dlogw_path, dlogY_path_try)
		dlogP_path_t[i+1],dlogY_path_t[i+1],_,_,_,_,_,_,_,_ = res
		cov_results = np.array(calculation_markup_passthrough_cov(ss_dist, policy_paths, P_ss, Y_ss, w_ss, params))
		# Converged when neither path moved by more than the threshold in absolute
		# value; a signed maximum would also accept large downward revisions.
		path_change = max(np.abs(dlogP_path_t[i+1] - dlogP_path_t[i]).max(), np.abs(dlogY_path_t[i+1] - dlogY_path_t[i]).max())
		if path_change < MIT_iter_thresh:
			print('Converged!')
			break

